import matplotlib.pyplot as plt
import numpy as np

# Seed NumPy's global RNG. Nothing visible in this script draws random
# numbers; the call is retained so the global RNG state is unchanged for any
# code that runs afterwards.
np.random.seed(19680801)


fig, ax = plt.subplots()

# The four input series share one directory; each file is parsed as one
# floating-point value per line.
data_dir = "Main_results/2_fig2"
x1_f = data_dir + "/x1"
y1_f = data_dir + "/y1"
x2_f = data_dir + "/x2"
y2_f = data_dir + "/y2"


def _read_floats(path):
    """Return the values stored one per line in *path* as a 1-D float array.

    Blank and whitespace-only lines are skipped, so a stray empty line (for
    example at the end of the file) no longer aborts the script with a
    ``ValueError``. Any other line that is not a valid float still raises
    ``ValueError``; a missing file raises ``OSError``.
    """
    with open(path, 'r') as f:
        # float() tolerates surrounding whitespace, including the newline.
        return np.array([float(line) for line in f if line.strip()])


# Order matters: it must match the unpacking into x1, y1, x2, y2 below.
arrays = [_read_floats(path) for path in (x1_f, y1_f, x2_f, y2_f)]


x1, y1, x2, y2 = arrays

# Marker size passed to scatter() as ``s``; shared by both series.
scale = 200

# Per-series colour and legend label, index-aligned with the series below.
colors = ['tab:blue', 'tab:orange']
labels = ['ViT', 'OVIT']

# Draw each (x, y) series as semi-transparent, outline-free markers so that
# overlapping points remain visible.
for xs, ys, color, label in zip((x1, x2), (y1, y2), colors, labels):
    ax.scatter(
        xs,
        ys,
        c=color,
        s=scale,
        label=label,
        alpha=0.3,
        edgecolors='none',
    )

ax.legend()
ax.grid(True)

# Write the figure to disk; swap in plt.show() to preview it interactively.
plt.savefig('Main_results/2_fig2/2_scatters.png')